import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import os
import json
import time
import scipy
import scipy.linalg
import cvxpy as cp
import pandas as pd
from tqdm import tqdm
import warnings
from typing import Dict, Any, List, Optional, Tuple
import itertools
import math
import traceback

# Create necessary directories
# Output folders for figures and result files. They are created at import time;
# exist_ok=True makes this a no-op when they already exist.
os.makedirs("plots", exist_ok=True)
os.makedirs("results", exist_ok=True)

# Suppress common warnings
# Silence expected noise: RuntimeWarnings raised from cvxpy, "Solution may be
# inaccurate" solver notices, and tqdm's notebook-mode FutureWarning.
warnings.filterwarnings("ignore", category=RuntimeWarning, module="cvxpy")
warnings.filterwarnings("ignore", category=UserWarning, message="Solution may be inaccurate.*")
warnings.filterwarnings("ignore", message="Using `tqdm.autonotebook.tqdm` in notebook mode.*", category=FutureWarning)

# =============================================================================
# IMPROVED VISUALIZATION SETTINGS
# =============================================================================
# Per-series plot styling. The keys are the same strings used as result keys in
# run_experiments_flexible (e.g. 'Error SRHT (Actual)'). The values look like
# matplotlib Line2D keyword arguments ('color', 'marker', 'lw', 'zorder', ...);
# presumably they are unpacked by the plotting code -- TODO confirm at the call site.
# Higher 'zorder' draws on top, so the optimum and the user's bounds sit above the
# empirical curves.
IMPROVED_STYLES = {
    'Optimal Error v_k^*': {
        'color': 'gold', 'marker': '*', 'linestyle': '-', 'label': r'Optimal $v_k^*$',
        'lw': 4.0, 'markersize': 16, 'zorder': 10, 'markeredgewidth': 1.5, 'markeredgecolor': 'black'
    },
    'Your Bound (QP CVXPY Best)': {
        'color': 'black', 'marker': 'o', 'linestyle': '-', 'label': 'Bound (QP Best)',
        'lw': 3.5, 'markersize': 12, 'zorder': 9, 'markeredgewidth': 1.0
    },
    'Your Bound (QP Analytical)': {
        'color': 'dimgrey', 'marker': '^', 'linestyle': ':', 'label': 'Bound (QP Approx)',
        'lw': 3.0, 'markersize': 12, 'zorder': 8, 'markeredgewidth': 1.0
    },
    'Your Bound (Binary)': {
        'color': 'darkgrey', 'marker': 's', 'linestyle': '--', 'label': 'Bound (Binary)',
        'lw': 3.0, 'markersize': 12, 'zorder': 7, 'markeredgewidth': 1.0
    },
    'Bound (Leverage Score Exp.)': {
        'color': 'deepskyblue', 'marker': 'D', 'linestyle': '-.', 'label': 'Bound (Lev. Score Exp.)',
        'lw': 3.0, 'markersize': 12, 'zorder': 6, 'markeredgewidth': 1.0, 'markeredgecolor': 'navy'
    },
    'Bound (Sketching Simple)': {
        'color': 'sandybrown', 'marker': 'P', 'linestyle': ':', 'label': 'Bound (Sketching Simple)',
        'lw': 3.0, 'markersize': 12, 'zorder': 5, 'markeredgewidth': 1.0, 'markeredgecolor': 'saddlebrown'
    },
    'Error Leverage Score (Actual)': {
        'color': 'blue', 'marker': 'x', 'linestyle': '-', 'label': 'Leverage Score Sampling',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 4, 'markeredgewidth': 2.0
    },
    'Error CountSketch (Actual)': {
        'color': 'orange', 'marker': 'd', 'linestyle': '--', 'label': 'CountSketch',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 3, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkorange'
    },
    'Error SRHT (Actual)': {
        'color': 'red', 'marker': 'v', 'linestyle': '-.', 'label': 'SRHT',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 2, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkred'
    },
    'Error Gaussian (Actual)': {
        'color': 'darkviolet', 'marker': '<', 'linestyle': ':', 'label': 'Gaussian Proj.',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 1, 'markeredgewidth': 1.5, 'markeredgecolor': 'indigo'
    },
    'Error Greedy OMP (Actual)': {
        'color': 'forestgreen', 'marker': '>', 'linestyle': '-', 'label': 'Greedy OMP',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 0, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkgreen'
    },
}

# Global matplotlib defaults for every figure created after import: large bold
# fonts, 300-dpi saved files, grid on by default, and Computer Modern math text.
plt.rcParams.update({
    'font.size': 20, 'axes.titlesize': 24, 'axes.labelsize': 22,
    'xtick.labelsize': 20, 'ytick.labelsize': 20, 'legend.fontsize': 20,
    'figure.titlesize': 26, 'figure.figsize': (18, 8), 'figure.dpi': 150,
    'savefig.dpi': 300, 'lines.linewidth': 3, 'lines.markersize': 12,
    'axes.linewidth': 1.5, 'grid.linewidth': 1.0, 'axes.grid': True,
    'grid.alpha': 0.3, 'axes.titleweight': 'bold', 'axes.labelweight': 'bold',
    'figure.titleweight': 'bold', 'mathtext.default': 'regular', 'mathtext.fontset': 'cm',
})

# --- Basic Helper Functions ---
def frob_norm_sq(M: np.ndarray) -> float:
    """Return ``||M||_F^2``, evaluating the norm in float64 precision."""
    as_float64 = np.asarray(M).astype(np.float64)
    frobenius = np.linalg.norm(as_float64, 'fro')
    return frobenius ** 2

def col_norms_sq(M: np.ndarray) -> np.ndarray:
    """Return the squared Euclidean norm of each column (axis 0) of ``M`` in float64."""
    per_column = np.linalg.norm(np.asarray(M).astype(np.float64), axis=0)
    return per_column ** 2

def calculate_rho_g(A: np.ndarray, B: np.ndarray) -> float:
    """Return the structural ratio rho_G = trace(G) / (1^T G 1), with G = (A^T A) * (B^T B).

    ``*`` is the element-wise (Hadamard) product, so G_ij = (a_i . a_j)(b_i . b_j) and
    1^T G 1 equals ``||A B^T||_F^2``.

    Args:
        A: (m, n) matrix.
        B: (p, n) matrix with the same number of columns as ``A``.

    Returns:
        The ratio as a float. Special cases:
        - 0.0 when n == 0 or G is numerically zero;
        - ``inf`` when ``||A B^T||_F^2`` is ~0 but G is not (the columns cancel);
        - NaN (with a warning and a printed traceback) on any error, including a
          column-count mismatch.
    """
    try:
        A = np.asarray(A, dtype=np.float64)
        B = np.asarray(B, dtype=np.float64)
        _, n = A.shape
        _, n_b = B.shape
        if n != n_b:
            raise ValueError("Dimension mismatch in calculate_rho_g (A vs B)")
        if n == 0:
            return 0.0
        G = (A.T @ A) * (B.T @ B)
        trace_G = float(np.trace(G))
        sum_G = float(np.sum(G))
        if sum_G <= 1e-12:
            # ||AB^T||_F^2 ~ 0: either everything is zero, or non-zero terms cancel.
            return 0.0 if np.linalg.norm(G, 'fro') < 1e-12 else np.inf
        return max(0.0, trace_G / sum_G)
    except Exception as e:
        # Deliberate best-effort: callers test for NaN rather than catching.
        warnings.warn(f"Error calculating Rho_G: {e}")
        traceback.print_exc()
        return np.nan

# --- Matrix Generation ---
def generate_matrices(m: int, p: int, n: int, cancellation_pairs: int = 0, noise_level: float = 0.0,
                      seed: Optional[int] = None, distribution: str = 'gaussian') -> Tuple[np.ndarray, np.ndarray]:
    """Generate a random pair (A, B) of shapes (m, n) and (p, n), optionally with cancelling columns.

    Uses (and, when ``seed`` is given, reseeds) the global NumPy RNG.

    Args:
        m, p, n: Matrix dimensions.
        cancellation_pairs: Number of disjoint column pairs (i, j) made to cancel in A @ B.T.
            For each pair, A[:, j] = A[:, i] and B[:, j] = -B[:, i], then both are rescaled by a
            shared random factor. Requests larger than n // 2 are clamped to n // 2.
        noise_level: Additive Gaussian noise with standard deviation noise_level * std(matrix),
            or noise_level itself when the matrix is ~constant.
        seed: Optional seed for ``np.random.seed``.
        distribution: 'gaussian' (standard normal) or 'uniform' (on [-1, 1)).

    Returns:
        The tuple (A, B).

    Raises:
        ValueError: For an unsupported ``distribution``.
    """
    if seed is not None:
        np.random.seed(seed)
    if distribution == 'gaussian':
        A = np.random.randn(m, n)
        B = np.random.randn(p, n)
    elif distribution == 'uniform':
        A = np.random.rand(m, n) * 2 - 1
        B = np.random.rand(p, n) * 2 - 1
    else:
        raise ValueError("Unsupported distribution")
    # Clamp to the number of disjoint pairs that fit. Previously an over-large request skipped
    # cancellation entirely and the intended clamp was dead code.
    n_pairs = min(cancellation_pairs, n // 2)
    if n_pairs > 0:
        indices = np.random.choice(n, 2 * n_pairs, replace=False)
        for i in range(n_pairs):
            idx1, idx2 = indices[2 * i], indices[2 * i + 1]
            # a_j b_j^T = -a_i b_i^T, so the pair contributes nothing to A @ B.T.
            A[:, idx2] = A[:, idx1]
            B[:, idx2] = -B[:, idx1]
            scale_factor = np.random.uniform(0.5, 1.5)
            A[:, idx1] *= scale_factor
            A[:, idx2] *= scale_factor
            B[:, idx1] *= scale_factor
            B[:, idx2] *= scale_factor
    if noise_level > 0:
        std_A = np.std(A)
        A += np.random.normal(0, noise_level * std_A if std_A > 1e-9 else noise_level, size=(m, n))
        std_B = np.std(B)
        B += np.random.normal(0, noise_level * std_B if std_B > 1e-9 else noise_level, size=(p, n))
    return A, B

def generate_matrices_for_rho(target_rho: float, m: int, p: int, n: int, tolerance: float = 0.25,
                              max_attempts: int = 2500, base_seed: int = 0) -> Optional[Tuple[np.ndarray, np.ndarray, float]]:
    """Search seeds for an (A, B) pair whose rho_G is close to ``target_rho``.

    Attempt ``t`` reseeds the global NumPy RNG with ``(base_seed + t) mod 2**32``. It draws
    cancellation/noise parameters from a regime chosen by the target (<= 1, <= 10, > 10),
    generates a pair, and accepts it when the relative gap (absolute gap for targets
    <= 1e-6) is below ``tolerance``.

    Returns:
        ``(A, B, rho)`` on success, otherwise ``None`` after ``max_attempts`` (with a warning).
    """
    print(f"Attempting to generate matrices for target ρG ≈ {target_rho:.2f} (n={n})...")
    for trial in range(max_attempts):
        np.random.seed((base_seed + trial) % (2**32))
        if target_rho <= 1.0:
            pairs = 0 if np.random.rand() < 0.7 else np.random.randint(0, max(1, n // 20))
            noise = np.random.uniform(0, 0.05)
        elif target_rho <= 10.0:
            pairs = np.random.randint(max(1, n // 10), max(2, n // 4))
            noise = np.random.uniform(0.01, 0.1)
        else:
            pairs = np.random.randint(max(1, n // 4), max(2, n // 2 + 1))
            noise = np.random.uniform(0.05, 0.15)
        A_gen, B_gen = generate_matrices(m, p, n, cancellation_pairs=pairs, noise_level=noise, seed=None)
        rho = calculate_rho_g(A_gen, B_gen)
        if not np.isfinite(rho):
            continue
        gap = abs(rho - target_rho)
        if target_rho > 1e-6:
            gap = gap / target_rho
        if gap < tolerance:
            print(f"  Success! Found ρG = {rho:.3f} (Target ≈ {target_rho:.2f}) after {trial + 1} attempts.")
            return A_gen, B_gen, rho
    warnings.warn(f"Failed to generate matrices for target ρG ≈ {target_rho:.2f} within {max_attempts} attempts.")
    return None

# --- USER'S BOUND FUNCTION ---
def compute_theoretical_bounds(data: Dict[str, Any], k: int) -> Dict[str, Any]:
    """Evaluate the user's analytical error bounds for keeping ``k`` of ``n`` rank-one terms.

    Every returned value is a ratio relative to ``data['frob_norm']`` (not squared).
    An entry stays NaN when its inputs are missing, its guard condition fails, or
    the solver fails.

    Args:
        data: Dict with required keys ``'n'`` (number of columns) and ``'frob_norm'``
            (run_experiments_flexible passes ``||A B^T||_F``). The optional keys
            ``'Gab'`` (n x n matrix G), ``'q'`` (length-n vector), ``'r'`` (read but
            unused) and ``'trace'`` (trace of G) enable the binary and QP bounds. The
            optional keys ``'A'`` (shape (m, n)) and ``'B'`` (shape (n, p), i.e. the
            transpose of the B used elsewhere) enable the greedy bound.
        k: Number of columns kept.

    Returns:
        Dict with keys ``'binary_ratio'``, ``'qp_ratio'`` (closed form),
        ``'qp_ratio_best'`` (QP solved numerically with cvxpy/SCS) and ``'Greedy_ratio'``.
    """
    n = data['n']; frob_norm_AB = data['frob_norm']
    binary_ratio, qp_ratio, vk_ratio, Greedy_bound_ratio = np.nan, np.nan, np.nan, np.nan
    has_qp_data = all(key in data for key in ['Gab', 'q', 'r', 'trace'])
    if has_qp_data:
        G = data['Gab']; q_vec = data['q']; r_val = data['r']; TrG = data['trace']
        # ||AB^T||_F^2. This equals 1^T G 1 when G is the Hadamard Gram (A^T A)*(B^T B) built by the caller.
        oneGone = frob_norm_AB**2
        if n > 1 and oneGone > 1e-12:
            # The inner conditionals below are redundant given this guard; kept as-is.
            rho_G_local = TrG / oneGone if oneGone > 1e-12 else 0.0
            # beta_k interpolates between the diagonal (k = 1) and the full G (k = n).
            beta_k = (k - 1) / (n - 1) if n > 1 else 0.0
            denominator = (beta_k + (1 - beta_k) * rho_G_local)
            if abs(denominator) > 1e-12:
                # Closed-form QP bound: squared ratio = 1 - (k / n) / (beta_k + (1 - beta_k) * rho_G).
                gamma = 1.0 / denominator
                qp_bound_sq_ratio = max(0, 1.0 - k * gamma / n)
                qp_ratio = np.sqrt(qp_bound_sq_ratio)
            alpha_k = k / (n - 1) if n > 1 else 0.0
            # Binary bound: (1 - k/n) * ((1 - alpha_k) * ||AB^T||^2 + alpha_k * trace(G)).
            binary_bound_sq = max(0, (1.0 - k * 1.0 / n) * ((1.0 - alpha_k) * oneGone + alpha_k * TrG))
            binary_ratio = np.sqrt(binary_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0
            if k > 0:
                # Numerical QP: min_{y >= 0} 0.5 y^T G_hat_k y - q^T y, where G_hat_k mixes G and diag(G) with weight beta_k.
                G_hat_k = beta_k * G + (1 - beta_k) * np.diag(np.diag(G))
                y = cp.Variable(n); constraints = [y >= 0]
                objective = cp.Minimize(0.5 * cp.quad_form(y, G_hat_k) - q_vec.T @ y)
                prob = cp.Problem(objective, constraints)
                try:
                    prob.solve(solver=cp.SCS, verbose=False, eps=1e-7, max_iters=10000)
                    if prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                        # y = 0 is feasible with objective 0, so in exact arithmetic prob.value <= 0
                        # and this bound does not exceed ||AB^T||^2.
                        v_k_bound_sq = max(0, oneGone + (k / n) * 2.0 * prob.value)
                        vk_ratio = np.sqrt(v_k_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0
                # Any solver failure leaves the numerical bound as NaN.
                except (cp.SolverError, Exception): vk_ratio = np.nan
            elif k == 0: vk_ratio = 1.0 if oneGone > 1e-12 else 0.0
    if 'A' in data and 'B' in data:
        A_data = data['A']; B_orig = data['B']
        # B_orig is expected as (n, p): rows are the b_j vectors. m_A and p_B are unused.
        m_A, n_A = A_data.shape; n_B, p_B = B_orig.shape
        if n_A == n and n_B == n:
            if not 0 <= k <= n: Greedy_bound_ratio = np.nan
            elif k == 0: Greedy_bound_ratio = 1.0 if frob_norm_AB > 1e-12 else 0.0
            elif k == n: Greedy_bound_ratio = 0.0
            else:
                # T_j = ||a_j||^2 ||b_j||^2; sum the n - k smallest (the terms dropped when keeping the k largest).
                # NOTE(review): sqrt(sum T_j) ignores the cross terms <a_i b_i^T, a_j b_j^T>. It equals the dropped
                # residual norm only when those vanish; confirm it is intended as an upper bound.
                norms_A_sq = np.sum(A_data * A_data, axis=0); norms_B_sq = np.sum(B_orig * B_orig, axis=1)
                T = norms_A_sq * norms_B_sq; indices_sorted_by_T = np.argsort(T)
                J_complement = indices_sorted_by_T[:n-k]; sum_T_complement = np.sum(T[J_complement])
                Greedy_bound_val_sq = max(0, sum_T_complement); Greedy_bound_val = np.sqrt(Greedy_bound_val_sq)
                if frob_norm_AB > 1e-12: Greedy_bound_ratio = Greedy_bound_val / frob_norm_AB
                else: Greedy_bound_ratio = 0.0 if Greedy_bound_val < 1e-12 else np.inf
    return {'binary_ratio': binary_ratio, 'qp_ratio': qp_ratio, 'qp_ratio_best': vk_ratio, 'Greedy_ratio': Greedy_bound_ratio}

# --- STANDARD BOUNDS FUNCTION ---
def compute_standard_bounds(A: np.ndarray, B: np.ndarray, k: int, frob_ABt_sq: float) -> Dict[str, float]:
    """Return classic squared relative error bounds for approximating ``A @ B.T`` with k samples.

    - 'Bound (Leverage Score Exp.)': expected squared error of norm-proportional sampling,
      ((sum_i ||a_i|| ||b_i||)^2 - ||AB^T||_F^2) / k, divided by ||AB^T||_F^2.
    - 'Bound (Sketching Simple)': ||A||_F^2 ||B||_F^2 / k, divided by ||AB^T||_F^2.

    Args:
        A: (m, n) matrix.
        B: (p, n) matrix.
        k: Number of samples / sketch size.
        frob_ABt_sq: Precomputed ``||A B^T||_F^2``.

    Returns:
        Dict with both keys. For k == 0 the values are 1.0 and ``inf``. They are NaN for
        degenerate input (mismatched or empty n, or ``||AB^T|| ~ 0`` with k != 0), or
        when a computation fails (with a RuntimeWarning).
    """
    m, n = A.shape; p, n2 = B.shape
    bound_leverage_exp_sq_ratio, bound_sketching_simple_sq_ratio = np.nan, np.nan
    if n != n2 or n == 0 or frob_ABt_sq < 1e-20:
        if k == 0 and frob_ABt_sq > 1e-20: bound_leverage_exp_sq_ratio = 1.0; bound_sketching_simple_sq_ratio = np.inf
        return {'Bound (Leverage Score Exp.)': bound_leverage_exp_sq_ratio, 'Bound (Sketching Simple)': bound_sketching_simple_sq_ratio}
    if k == 0:
        # With no samples the estimate is 0, so the relative error is exactly 1; the 1/k bound is unbounded.
        # Return now: the formulas below divide by k and previously overwrote these values with inf/nan.
        return {'Bound (Leverage Score Exp.)': 1.0, 'Bound (Sketching Simple)': np.inf}
    A_f64 = A.astype(np.float64); B_f64 = B.astype(np.float64)
    try:
        norms_A = np.linalg.norm(A_f64, axis=0); norms_B = np.linalg.norm(B_f64, axis=0)
        sum_prod_norms = max(0, np.sum(norms_A * norms_B))
        expected_error_sq_abs = (sum_prod_norms**2 - frob_ABt_sq) / k
        bound_leverage_exp_sq_ratio = max(0, expected_error_sq_abs) / frob_ABt_sq
    except Exception as e: warnings.warn(f"Failed to compute Leverage Score bound for k={k}: {e}", RuntimeWarning)
    try:
        frob_A_sq = frob_norm_sq(A_f64); frob_B_sq = frob_norm_sq(B_f64)
        sketching_bound_sq = (frob_A_sq * frob_B_sq) / k
        bound_sketching_simple_sq_ratio = max(0, sketching_bound_sq / frob_ABt_sq)
    except Exception as e: warnings.warn(f"Failed to compute Simple Sketching bound for k={k}: {e}", RuntimeWarning)
    return {'Bound (Leverage Score Exp.)': bound_leverage_exp_sq_ratio, 'Bound (Sketching Simple)': bound_sketching_simple_sq_ratio}

# --- Algorithm Implementations ---
def run_leverage_score_sampling(A: np.ndarray, B: np.ndarray, k: int, optimal: bool = True, replacement: bool = False) -> np.ndarray:
    """Approximate ``A @ B.T`` from ``k`` importance-sampled, rescaled column pairs.

    Column i is drawn with probability proportional to ``||a_i|| * ||b_i||`` when ``optimal``,
    otherwise uniformly. Probabilities are floored at 1e-12 and renormalised. Each sampled pair
    is rescaled by ``1 / sqrt(k * p_i)``. Without replacement, ``k`` is capped at n. Uses the
    global NumPy RNG.

    NOTE(review): the 1/(k p_i) weights give an unbiased estimator only for sampling with
    replacement; the default ``replacement=False`` with non-uniform probabilities is biased.

    Raises:
        ValueError: On a column-count mismatch or non-positive ``k``.
    """
    m, n = A.shape
    p, n_b = B.shape
    if n != n_b:
        raise ValueError("Matrices A and B must have the same number of columns (n)")
    if k <= 0:
        raise ValueError(f"k must be positive for Leverage Score Sampling, got {k}")
    if n == 0:
        return np.zeros((m, p), dtype=A.dtype)
    if not replacement and k > n:
        k = n
    if not optimal:
        base_probs = np.ones(n) / n
    else:
        weights = np.linalg.norm(A, axis=0) * np.linalg.norm(B, axis=0)
        weight_total = np.sum(weights)
        base_probs = np.ones(n) / n if weight_total < 1e-12 else weights / weight_total
    # Keep every probability strictly positive, then renormalise.
    probs = np.maximum(base_probs, 1e-12)
    probs /= probs.sum()
    picked = np.random.choice(n, size=k, replace=replacement, p=probs)
    rescale = 1.0 / np.sqrt(k * probs[picked])
    return (A[:, picked] * rescale) @ (B[:, picked] * rescale).T

def run_countsketch(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate ``A @ B.T`` with a CountSketch of width ``k``.

    Each column j is hashed to bucket ``h[j]`` in [0, k) with random sign ``g[j]``. The signed
    columns are summed per bucket to form SA (m, k) and SB (p, k), and SA @ SB.T is returned.
    Uses the global NumPy RNG (``h`` is drawn before ``g``).

    Integer inputs are accumulated in float64. Previously the in-place float add into an
    int-dtype buffer raised a casting error.

    Raises:
        ValueError: On a column-count mismatch or non-positive ``k``.
    """
    m, n = A.shape
    p, n2 = B.shape
    if n != n2: raise ValueError("Dimension mismatch")
    if k <= 0: raise ValueError("k must be positive for CountSketch")
    if n == 0: return np.zeros((m, p), dtype=A.dtype)
    h = np.random.randint(0, k, size=n)
    g = np.random.choice([-1.0, 1.0], size=n)
    dtype_A = A.dtype if np.issubdtype(A.dtype, np.inexact) else np.float64
    dtype_B = B.dtype if np.issubdtype(B.dtype, np.inexact) else np.float64
    SA = np.zeros((m, k), dtype=dtype_A)
    SB = np.zeros((p, k), dtype=dtype_B)
    # Unbuffered scatter-add: repeated bucket indices accumulate (in column order) instead of overwriting.
    np.add.at(SA, (slice(None), h), (A * g).astype(dtype_A, copy=False))
    np.add.at(SB, (slice(None), h), (B * g).astype(dtype_B, copy=False))
    return SA @ SB.T

def run_gaussian_projection(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate ``A @ B.T`` as ``(A S^T)(B S^T)^T``.

    S is a (k, n) matrix of i.i.d. N(0, 1/k) entries drawn from the global NumPy RNG.

    Raises:
        ValueError: On a column-count mismatch or non-positive ``k``.
    """
    rows_a, n_cols = A.shape
    rows_b, n_cols_b = B.shape
    if n_cols != n_cols_b:
        raise ValueError("Dimension mismatch")
    if k <= 0:
        raise ValueError("k must be positive for Gaussian Projection")
    if n_cols == 0:
        return np.zeros((rows_a, rows_b), dtype=A.dtype)
    sketch_t = (np.random.randn(k, n_cols) / np.sqrt(k)).T
    left = A @ sketch_t
    right = B @ sketch_t
    return left @ right.T

def run_greedy_selection_omp(A: np.ndarray, B: np.ndarray, k: int, ABt_exact: Optional[np.ndarray] = None) -> np.ndarray:
    """Greedily pick ``k`` column pairs and return ``A[:, S] @ B[:, S].T``.

    At each step, the unselected column j maximising ``|<R, a_j b_j^T>_F| = |a_j^T R b_j|``
    is chosen. Ties go to the lowest index, and NaN scores are never chosen. The residual is
    then recomputed as ``AB^T - A[:, S] B[:, S]^T``. Despite the name, selected terms keep unit
    weights; no least-squares refit is done.

    Args:
        A: (m, n) matrix.
        B: (p, n) matrix.
        k: Number of columns to select, 0 <= k <= n.
        ABt_exact: Optional precomputed ``A @ B.T``.

    Raises:
        ValueError: On a column-count mismatch or k outside [0, n].
    """
    m, n = A.shape
    p_dim_B, n2 = B.shape
    if n != n2: raise ValueError("Dimension mismatch")
    if not (0 <= k <= n): raise ValueError(f"k={k} must be between 0 and n={n} for Greedy OMP")
    if n == 0 or k == 0: return np.zeros((m, p_dim_B), dtype=A.dtype)
    target = np.asarray(A @ B.T if ABt_exact is None else ABt_exact, dtype=np.float64)
    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)
    residual = target
    available = np.ones(n, dtype=bool)
    selected_indices = []
    for _ in range(k):
        # Scores for every column at once: (A^T R)[j] . b_j = a_j^T R b_j.
        # This avoids building an m x p outer product per candidate.
        scores = np.abs(np.einsum('ij,ij->i', A_f64.T @ residual, B_f64.T))
        scores[~available | np.isnan(scores)] = -np.inf
        best = int(np.argmax(scores))
        if np.isneginf(scores[best]):
            break  # only NaN-scored candidates remain
        available[best] = False
        selected_indices.append(best)
        residual = target - A_f64[:, selected_indices] @ B_f64[:, selected_indices].T
    if not selected_indices: return np.zeros((m, p_dim_B), dtype=A.dtype)
    return A[:, selected_indices] @ B[:, selected_indices].T

def fast_walsh_hadamard_transform_manual(X: np.ndarray, axis: int = -1) -> np.ndarray:
    """Unnormalised Walsh-Hadamard transform of ``X`` along ``axis`` (recursive even/odd split).

    Each level computes ``[E + O, E - O]``, where E and O are the transforms of the even- and
    odd-indexed entries. The result is converted to float. For a size-1 axis the
    (float-converted) input itself is returned.

    Raises:
        ValueError: If the length along ``axis`` is not a positive power of two.
    """
    data = np.asarray(X, dtype=float)
    length = data.shape[axis]
    requested_axis = axis
    ax = data.ndim + axis if axis < 0 else axis
    if length <= 0 or (length & (length - 1)) != 0:
        raise ValueError(f"Input size along axis {requested_axis} must be a power of 2, got {length}")
    if length == 1:
        return data

    def _along(sl):
        # Index tuple selecting ``sl`` on ``ax`` and everything on other axes.
        index = [slice(None)] * data.ndim
        index[ax] = sl
        return tuple(index)

    evens = fast_walsh_hadamard_transform_manual(data[_along(slice(None, None, 2))], axis=ax)
    odds = fast_walsh_hadamard_transform_manual(data[_along(slice(1, None, 2))], axis=ax)
    half = length // 2
    combined = np.empty_like(data)
    combined[_along(slice(0, half))] = evens + odds
    combined[_along(slice(half, length))] = evens - odds
    return combined

def pad_matrix(A: np.ndarray, axis: int = 1) -> Tuple[np.ndarray, int]:
    """Zero-pad ``A`` along ``axis`` up to the next power of two.

    Returns:
        ``(padded, n_orig)``, where ``n_orig`` is the original length along ``axis``. ``A``
        itself (not a copy) is returned when that length is 0 or already a power of two.
    """
    n_orig = A.shape[axis]
    if n_orig == 0:
        return A, 0
    # Smallest power of two >= n_orig. (0).bit_length() == 0, so n_orig == 1 maps to 1;
    # the old "next_pow_2 < n_orig" fallback could never trigger and has been removed.
    next_pow_2 = 1 << (n_orig - 1).bit_length()
    if next_pow_2 == n_orig:
        return A, n_orig
    padding_spec = [(0, 0)] * A.ndim
    padding_spec[axis] = (0, next_pow_2 - n_orig)
    return np.pad(A, pad_width=padding_spec, mode='constant', constant_values=0), n_orig

def run_srht_new(A: np.ndarray, B: np.ndarray, k: int, optimal_sampling: bool = False) -> np.ndarray:
    """Approximate ``A @ B.T`` with a subsampled randomized Hadamard transform.

    Columns are zero-padded to a power of two N. Both matrices are multiplied by the same
    random +/-1 diagonal and transformed with the orthonormal Hadamard transform (1/sqrt(N)).
    Then ``min(k, N)`` transformed columns are sampled uniformly without replacement and
    scaled by sqrt(N / k). Uses the global NumPy RNG (signs first, then column indices).
    ``optimal_sampling`` is accepted but unused.

    Raises:
        ValueError: On a column-count mismatch or non-positive ``k``.
        RuntimeError: If padding or the Hadamard transform fails.
    """
    m, n = A.shape
    p_dim_B, n2 = B.shape
    if n != n2:
        raise ValueError("Dimension mismatch")
    if k <= 0:
        raise ValueError(f"k must be positive for SRHT, got {k}")
    if n == 0:
        return np.zeros((m, p_dim_B), dtype=A.dtype)
    A_padded = pad_matrix(A, axis=1)[0]
    B_padded = pad_matrix(B, axis=1)[0]
    N_padded = A_padded.shape[1]
    if N_padded != B_padded.shape[1]:
        raise RuntimeError(f"Padded dimensions mismatch for SRHT: A_padded {A_padded.shape}, B_padded {B_padded.shape}")
    if N_padded == 0:
        return np.zeros((m, p_dim_B), dtype=A.dtype)
    k_eff = min(k, N_padded)
    if k_eff < k:
        warnings.warn(f"SRHT sampling k reduced from {k} to {k_eff} due to padded dim N={N_padded}", RuntimeWarning)
    if k_eff == 0:
        return np.zeros((m, p_dim_B), dtype=A.dtype)
    signs = np.random.choice([-1.0, 1.0], size=N_padded)
    A_signed = A_padded * signs
    B_signed = B_padded * signs
    try:
        norm = np.sqrt(N_padded)
        HA = fast_walsh_hadamard_transform_manual(A_signed, axis=1) / norm
        HB = fast_walsh_hadamard_transform_manual(B_signed, axis=1) / norm
    except ValueError as ve:
        raise RuntimeError(f"Manual FWHT failed: {ve}. Input shapes: A_signed {A_signed.shape}, B_signed {B_signed.shape}")
    except Exception as e:
        raise RuntimeError(f"Unexpected error in Manual FWHT: {e}\n{traceback.format_exc()}")
    columns = np.random.choice(N_padded, size=k_eff, replace=False)
    scale = np.sqrt(N_padded / k_eff)
    return (HA[:, columns] * scale) @ (HB[:, columns] * scale).T

# --- Optimal v_k* Calculation ---
def compute_optimal_vk_star(A: np.ndarray, B: np.ndarray, k: int, ABt_exact: np.ndarray, threshold: int = 100000000):
    """Best possible squared error when approximating ``A @ B.T`` by k weighted rank-one terms.

    For every size-k column subset S, solves
    ``min_w ||A B^T - sum_{i in S} w_i a_i b_i^T||_F^2`` in closed form and returns the minimum
    over all subsets. This is brute force over ``math.comb(n, k)`` subsets.

    Args:
        A: (m, n) matrix; column i is a_i.
        B: (p, n) matrix; column i is b_i.
        k: Subset size.
        ABt_exact: ``A @ B.T`` with shape (m, p).
        threshold: Maximum number of subsets to enumerate.

    Returns:
        The minimal absolute (unnormalised) squared Frobenius error. Special cases:
        - ``||AB^T||_F^2`` for k == 0;
        - 0.0 for k == n;
        - NaN when k is out of range (with a RuntimeWarning) or there are more than
          ``threshold`` subsets.
    """
    _, n_cols = A.shape
    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)
    ABt_exact_f64 = ABt_exact.astype(np.float64)
    frob_ABt_sq_local = frob_norm_sq(ABt_exact_f64)
    if k < 0 or k > n_cols:
        warnings.warn(f"k={k} is out of bounds for n_cols={n_cols}.", RuntimeWarning)
        return np.nan
    if k == 0:
        return frob_ABt_sq_local
    if k == n_cols:
        return 0.0
    # Here 1 <= k <= n_cols - 1, so math.comb cannot raise.
    if math.comb(n_cols, k) > threshold:
        return np.nan
    # Gram matrix of the rank-one terms: G_ij = <a_i b_i^T, a_j b_j^T>_F = (a_i . a_j)(b_i . b_j).
    G_full_coeffs = (A_f64.T @ A_f64) * (B_f64.T @ B_f64)
    # q_i = <AB^T, a_i b_i^T>_F = a_i^T (AB^T) b_i, computed for all i at once.
    RHS_full_coeffs = np.einsum('ij,ij->i', A_f64.T @ ABt_exact_f64, B_f64.T)
    min_error_sq = frob_ABt_sq_local
    for indices_tuple in itertools.combinations(range(n_cols), k):
        indices = list(indices_tuple)
        Gram_S_k = G_full_coeffs[np.ix_(indices, indices)]
        RHS_S_k = RHS_full_coeffs[indices]
        try:
            # q lies in the range of G, so lstsq's minimum-norm solution attains the optimum even when G_S is singular.
            w_opt = np.linalg.lstsq(Gram_S_k, RHS_S_k, rcond=None)[0]
        except np.linalg.LinAlgError:
            warnings.warn(f"Unexpected LinAlgError with lstsq for k={k}, indices={indices}. Skipping.", RuntimeWarning)
            continue
        # At the optimum the error is ||AB^T||^2 - q^T w; clip small negative round-off.
        current_error_sq = max(0, frob_ABt_sq_local - np.dot(RHS_S_k, w_opt))
        min_error_sq = min(min_error_sq, current_error_sq)
    return min_error_sq

# --- Experiment Runner ---
class ExperimentFailureError(Exception):
    """Raised when a stage of an experiment run fails and the run must be aborted."""

def run_experiments_flexible(A, B, k_values, n_trials=100, compute_vk_star=False, vk_star_threshold=100000,
                             compute_algorithms=True, compute_bounds=True):
    """Compute bounds and empirical sketching errors for one (A, B) pair over several k.

    Every error and bound entry is a *squared* relative error, i.e. divided by
    ``||A B^T||_F^2``. The user's bound ratios are squared before they are stored.

    Args:
        A: (m, n) matrix.
        B: (p, n) matrix; the target product is ``A @ B.T``.
        k_values: Candidate sample sizes. Values outside [0, n] are dropped; the rest are
            de-duplicated and sorted with ``np.unique``.
        n_trials: Number of repetitions averaged for each randomized algorithm.
        compute_vk_star: Whether to compute the brute-force optimum v_k^*.
        vk_star_threshold: Maximum number of subsets enumerated per k for v_k^*.
        compute_algorithms: Whether to run the sketching and greedy algorithms.
        compute_bounds: Whether to evaluate the theoretical bounds.

    Returns:
        Dict with 'k' (the k values used), 'n', 'Frob ABT Sq', 'Rho_G', and one array per
        method/bound aligned with 'k'. Entries never computed stay NaN. All arrays are clipped
        to [0, 1e6]. If ``||AB^T||_F^2 < 1e-20``, the errors/bounds are filled with 0 (or inf
        for the simple sketching bound) and returned early.

    Raises:
        ExperimentFailureError: If setup, a bound, or an algorithm fails.

    The randomized algorithms draw from the global NumPy RNG, which is not seeded here.
    """
    m, n = A.shape; p_dim_B = B.shape[0]
    k_values = np.array(k_values, dtype=int); k_values = np.unique(k_values[(k_values >= 0) & (k_values <= n)])
    results_init_keys = ['k', 'n', 'Your Bound (Binary)', 'Your Bound (QP Analytical)', 'Your Bound (QP CVXPY Best)',
                         'Bound (Leverage Score Exp.)', 'Bound (Sketching Simple)', 'Error Leverage Score (Actual)',
                         'Error CountSketch (Actual)', 'Error SRHT (Actual)', 'Error Gaussian (Actual)',
                         'Error Greedy OMP (Actual)', 'Frob ABT Sq', 'Rho_G']
    if compute_vk_star: results_init_keys.append('Optimal Error v_k^*')
    # One NaN-filled array per method; the scalar/meta keys are excluded here and set just below.
    results = {key: (k_values if key == 'k' else (n if key == 'n' else np.full(len(k_values), np.nan)))
               for key in results_init_keys if key not in ['k', 'n', 'Frob ABT Sq', 'Rho_G']}
    results['k'] = k_values; results['n'] = n; results['Frob ABT Sq'] = np.nan; results['Rho_G'] = np.nan
    if len(k_values) == 0: warnings.warn("No valid k values."); return results
    if not any([compute_algorithms, compute_bounds, compute_vk_star]): warnings.warn("All computations disabled."); return results
    try:
        A_f64 = A.astype(np.float64); B_f64 = B.astype(np.float64)
        ABt_exact = A_f64 @ B_f64.T; frob_ABt_sq = frob_norm_sq(ABt_exact)
        frob_ABt = np.sqrt(frob_ABt_sq) if frob_ABt_sq > 1e-20 else 0.0
        results['Frob ABT Sq'] = frob_ABt_sq; results['Rho_G'] = calculate_rho_g(A_f64, B_f64)
        if frob_ABt_sq < 1e-20:
            # Relative errors are undefined; fill with 0 (inf for the 1/k sketching bound) and stop.
            warnings.warn(f"||AB^T||_F^2 is near zero ({frob_ABt_sq:.2e}). Errors/Bounds set to 0 or Inf.")
            for key_res in results:
                if 'Error' in key_res or 'Bound (Binary)' in key_res or 'Bound (QP' in key_res or 'Optimal' in key_res or 'Bound (Leverage' in key_res:
                    if isinstance(results[key_res], np.ndarray): results[key_res].fill(0.0)
                elif 'Bound (Sketching Simple)' in key_res:
                    if isinstance(results[key_res], np.ndarray): results[key_res].fill(np.inf)
            return results
        # Inputs for compute_theoretical_bounds: G = (A^T A) * (B^T B) (element-wise), its trace,
        # its total sum, and its row sums q.
        if n > 0:
            AtA = A_f64.T @ A_f64; BtB = B_f64.T @ B_f64; G_qp = AtA * BtB
            trace_G_qp = np.trace(G_qp); sum_G_qp = np.sum(G_qp); q_vec_qp = G_qp @ np.ones(n) # q_i = sum_j G_ij
        else: G_qp, trace_G_qp, sum_G_qp, q_vec_qp = np.array([[]]), 0.0, 0.0, np.array([])
        bound_data_dict = {'n': n, 'frob_norm': frob_ABt, 'A': A_f64, 'B': B_f64.T, # B.T is (n,p)
                           'Gab': G_qp, 'q': q_vec_qp, 'r': sum_G_qp, 'trace': trace_G_qp}
    except Exception as e: raise ExperimentFailureError(f"Initial setup (n={n}): {e}\n{traceback.format_exc()}")
    if compute_vk_star and 'Optimal Error v_k^*' in results:
        print(f"Computing Optimal v_k* (threshold={vk_star_threshold}, n={n})...")
        last_abs_vk_star_error = frob_ABt_sq # Initialize with error for k=0
        for i, k_val_vk in enumerate(tqdm(k_values, desc="Optimal v_k*", leave=False, disable=True)):
            try:
                optimal_abs_error_sq_current_k = compute_optimal_vk_star(A_f64, B_f64, k_val_vk, ABt_exact, vk_star_threshold)
                if not np.isnan(optimal_abs_error_sq_current_k):
                    # Check for non-monotonicity (allow for small numerical fluctuations)
                    if k_val_vk > 0 and optimal_abs_error_sq_current_k > last_abs_vk_star_error + 1e-9 * frob_ABt_sq :
                        warnings.warn(f"Optimal v_k* non-monotonic k={k_val_vk}: {optimal_abs_error_sq_current_k:.4e} > prev {last_abs_vk_star_error:.4e}.", RuntimeWarning)
                        # Do not update last_abs_vk_star_error if current is worse, to preserve monotonicity for next steps
                    else:
                        last_abs_vk_star_error = optimal_abs_error_sq_current_k # Update only if monotonic or first k
                    # The non-monotonic value is still stored; only the comparison baseline is held back.
                    results['Optimal Error v_k^*'][i] = optimal_abs_error_sq_current_k / frob_ABt_sq
                else: # Handle NaN (e.g. threshold exceeded)
                    if i > 0 and not np.isnan(results['Optimal Error v_k^*'][i-1]): results['Optimal Error v_k^*'][i] = results['Optimal Error v_k^*'][i-1] # Propagate previous
                    elif k_val_vk == 0: results['Optimal Error v_k^*'][i] = 1.0 # Error for k=0 is 1.0 (relative)
            except Exception as e: raise ExperimentFailureError(f"v_k^* for k={k_val_vk}: {e}\n{traceback.format_exc()}")
        if 'Optimal Error v_k^*' in results: results['Optimal Error v_k^*'] = np.maximum(0, results['Optimal Error v_k^*']) # Clip at 0
    print(f"Computing Bounds and Algorithm Errors (n_trials={n_trials}, n={n}, RhoG={results['Rho_G']:.2f})...")
    for i, k_val_main in enumerate(tqdm(k_values, desc=f"Exp (n={n}, RhoG={results['Rho_G']:.2f})", leave=False, disable=True)):
        if compute_bounds:
            try:
                # compute_theoretical_bounds returns un-squared ratios; square them to match the other entries.
                user_bounds = compute_theoretical_bounds(bound_data_dict, k_val_main)
                results['Your Bound (Binary)'][i] = user_bounds['binary_ratio']**2 if not np.isnan(user_bounds['binary_ratio']) else np.nan
                results['Your Bound (QP Analytical)'][i] = user_bounds['qp_ratio']**2 if not np.isnan(user_bounds['qp_ratio']) else np.nan
                results['Your Bound (QP CVXPY Best)'][i] = user_bounds['qp_ratio_best']**2 if not np.isnan(user_bounds['qp_ratio_best']) else np.nan
            except Exception as e: raise ExperimentFailureError(f"compute_theoretical_bounds k={k_val_main}: {e}\n{traceback.format_exc()}")
            try:
                std_bounds = compute_standard_bounds(A_f64, B_f64, k_val_main, frob_ABt_sq)
                results['Bound (Leverage Score Exp.)'][i] = std_bounds['Bound (Leverage Score Exp.)']
                results['Bound (Sketching Simple)'][i] = std_bounds['Bound (Sketching Simple)']
            except Exception as e: raise ExperimentFailureError(f"compute_standard_bounds k={k_val_main}: {e}\n{traceback.format_exc()}")
        if compute_algorithms:
            if k_val_main == 0: # For k=0, approximation is 0, error is ||AB^T||^2_F
                 results['Error Leverage Score (Actual)'][i] = 1.0; results['Error CountSketch (Actual)'][i] = 1.0
                 results['Error SRHT (Actual)'][i] = 1.0; results['Error Gaussian (Actual)'][i] = 1.0
                 results['Error Greedy OMP (Actual)'][i] = 1.0; continue
            errors_ls, errors_cs, errors_srht, errors_gauss = [], [], [], []
            trial_failed = False
            # Within each trial the four algorithms run in a fixed order on the shared global RNG.
            for _ in range(n_trials): # Run multiple trials for randomized algorithms
                try: errors_ls.append(frob_norm_sq(ABt_exact - run_leverage_score_sampling(A_f64, B_f64, k_val_main)))
                except Exception as e: trial_failed=True; warnings.warn(f"LevScore trial k={k_val_main}: {e}",RuntimeWarning); break
                try: errors_cs.append(frob_norm_sq(ABt_exact - run_countsketch(A_f64, B_f64, k_val_main)))
                except Exception as e: trial_failed=True; warnings.warn(f"CS trial k={k_val_main}: {e}",RuntimeWarning); break
                try: errors_srht.append(frob_norm_sq(ABt_exact - run_srht_new(A_f64, B_f64, k_val_main)))
                except Exception as e: trial_failed=True; warnings.warn(f"SRHT trial k={k_val_main}: {e}",RuntimeWarning); break
                try: errors_gauss.append(frob_norm_sq(ABt_exact - run_gaussian_projection(A_f64, B_f64, k_val_main)))
                except Exception as e: trial_failed=True; warnings.warn(f"Gauss trial k={k_val_main}: {e}",RuntimeWarning); break
            if trial_failed: raise ExperimentFailureError(f"Rand algo trial failed k={k_val_main}.")
            # Mean squared error over trials, relative to ||AB^T||_F^2.
            results['Error Leverage Score (Actual)'][i] = np.mean(errors_ls) / frob_ABt_sq if errors_ls else np.nan
            results['Error CountSketch (Actual)'][i] = np.mean(errors_cs) / frob_ABt_sq if errors_cs else np.nan
            results['Error SRHT (Actual)'][i] = np.mean(errors_srht) / frob_ABt_sq if errors_srht else np.nan
            results['Error Gaussian (Actual)'][i] = np.mean(errors_gauss) / frob_ABt_sq if errors_gauss else np.nan
            # Greedy OMP is deterministic, run once
            if 1 <= k_val_main <= n: # k_val_main > 0 is already true here
                try:
                     omp_err_sq = frob_norm_sq(ABt_exact - run_greedy_selection_omp(A_f64, B_f64, k_val_main, ABt_exact))
                     results['Error Greedy OMP (Actual)'][i] = omp_err_sq / frob_ABt_sq
                except Exception as e: raise ExperimentFailureError(f"Greedy OMP k={k_val_main}: {e}\n{traceback.format_exc()}")
            elif k_val_main == 0 : results['Error Greedy OMP (Actual)'][i] = 1.0 # Should be caught earlier
            else: results['Error Greedy OMP (Actual)'][i] = np.nan # k > n
    # Post-process binary bound to ensure monotonicity (it's an upper bound)
    # Running minimum over increasing k. Entries before the first finite value stay NaN.
    if 'Your Bound (Binary)' in results and len(results['Your Bound (Binary)']) > 0:
        raw_binary_bounds_sq = results['Your Bound (Binary)'].copy()
        processed_binary_bounds_sq = np.full_like(raw_binary_bounds_sq, np.nan)
        current_min_sq = np.inf
        for i_bin in range(len(raw_binary_bounds_sq)): # Iterate forwards
            if not np.isnan(raw_binary_bounds_sq[i_bin]): current_min_sq = min(current_min_sq, raw_binary_bounds_sq[i_bin])
            if np.isfinite(current_min_sq): processed_binary_bounds_sq[i_bin] = current_min_sq
            elif np.isnan(raw_binary_bounds_sq[i_bin]): processed_binary_bounds_sq[i_bin] = np.nan # Preserve NaNs if no valid min found yet
        results['Your Bound (Binary)'] = processed_binary_bounds_sq
    # Final clipping for all error/bound ratios (0 to a large number, e.g. 1e6, to avoid extreme plot scales from Inf)
    for key_final_clip in results:
         if key_final_clip not in ['k', 'n', 'Frob ABT Sq', 'Rho_G'] and isinstance(results[key_final_clip], np.ndarray):
             valid_mask = ~np.isnan(results[key_final_clip])
             if np.any(valid_mask): results[key_final_clip][valid_mask] = np.clip(results[key_final_clip][valid_mask], 0, 1e6)
    return results

# --- Experiment 4: Impact of ρG on Approximation Error ---
def run_experiment_rho_vs_error(target_rho_list, n_dim, m_dim, p_dim, k_values, n_trials, base_seed):
    """Run the approximation-error pipeline once per target structural complexity rho_G.

    For each entry of ``target_rho_list`` a matrix pair (A, B) is produced by
    ``generate_matrices_for_rho`` (tolerance 0.25, up to 1000 attempts, seed
    ``base_seed + 100 * index``). The pair is then evaluated by
    ``run_experiments_flexible`` with algorithms and bounds enabled and the
    v_k^* computation disabled.

    A target is reported and skipped when:
      * matrix generation returns None, or
      * the evaluation raises ``ExperimentFailureError``.

    Returns:
        dict mapping the achieved rho_G (third item returned by the generator)
        to the results object returned by ``run_experiments_flexible``.
    """
    print("\n=== Running Experiment 4: Impact of ρG on Approximation Error ===")
    print(f"Parameters: n={n_dim}, m={m_dim}, p={p_dim}")
    print(f"k values: {k_values}")
    formatted_targets = [format(rho, '.2f') for rho in target_rho_list]
    print(f"Target ρG values: {formatted_targets}")

    results_by_rho = {}
    for offset_index, target_rho in enumerate(target_rho_list):
        print(f"\nProcessing target ρG = {target_rho:.2f}")
        generated = generate_matrices_for_rho(
            target_rho=target_rho,
            m=m_dim,
            p=p_dim,
            n=n_dim,
            tolerance=0.25,
            max_attempts=1000,
            base_seed=base_seed + offset_index * 100,  # distinct seed stream per target
        )
        if generated is None:
            print(f"Failed to generate matrices for target ρG={target_rho}. Skipping.")
            continue

        A, B, achieved_rho = generated
        try:
            outcome = run_experiments_flexible(
                A=A,
                B=B,
                k_values=k_values,
                n_trials=n_trials,
                compute_vk_star=False,  # v_k^* is too slow to compute for every target
                compute_algorithms=True,
                compute_bounds=True,
            )
        except ExperimentFailureError as err:
            print(f"Experiment failed for ρG={achieved_rho}: {err}")
            continue
        results_by_rho[achieved_rho] = outcome
    return results_by_rho

def plot_rho_vs_error_multi_k(results_data: Dict[float, Dict], n_dim: int, k_values: List[int], styles: Dict):
    rho_values_plot = sorted(list(results_data.keys()))
    if not rho_values_plot:
        print("No results to plot for plot_rho_vs_error_multi_k.")
        return

    fig, axes = plt.subplots(1, len(k_values), figsize=(7 * len(k_values), 7), sharey=True)
    if len(k_values) == 1: axes = [axes] # Ensure axes is always iterable

    fig.suptitle(f"Impact of structural complexity ($\\rho_G$) on approximation error (n={n_dim})",
                 fontsize=plt.rcParams['figure.titlesize'], fontweight='bold', y=1.02)

    # Determine adaptive y-limits based on actual data to be plotted (sqrt of squared errors)
    all_y_plot_values_for_lim = []
    for rho_val_plot_lim in rho_values_plot:
        result_item_lim = results_data[rho_val_plot_lim]
        for method_key_plot_lim in styles.keys():
            if method_key_plot_lim == 'Optimal Error v_k^*': continue # Skip v_k* for this plot type
            if method_key_plot_lim in result_item_lim and isinstance(result_item_lim[method_key_plot_lim], np.ndarray):
                for k_val_plot_lim in k_values: # Iterate through the k_values for subplots
                    k_idx_in_res_lim = np.where(result_item_lim['k'] == k_val_plot_lim)[0]
                    if len(k_idx_in_res_lim) > 0:
                        val_to_plot_lim = result_item_lim[method_key_plot_lim][k_idx_in_res_lim[0]]
                        # Data for y-axis is sqrt of relative squared error.
                        # Filter ensures val_to_plot_lim > 1e-12, so sqrt(val_to_plot_lim) > 1e-6.
                        if not np.isnan(val_to_plot_lim) and val_to_plot_lim > 1e-12: # Avoid log(0) or sqrt(negative)
                            all_y_plot_values_for_lim.append(np.sqrt(val_to_plot_lim))

    if all_y_plot_values_for_lim:
        y_min_data = min(all_y_plot_values_for_lim) # Known to be > 1e-6
        y_max_data = max(all_y_plot_values_for_lim) # Known to be > 1e-6

        if y_min_data == y_max_data: # Handle case where all data points are the same
            y_min_plot = y_min_data * 0.5 # y_min_data is > 1e-6, so 0.5 * y_min_data is safe and positive
            y_max_plot = y_min_data * 1.5
        else: # General case with a range of data
            y_min_plot = y_min_data * 0.8  # 20% margin below min data (will be > 0.8e-6)
            y_max_plot = y_max_data * 1.2  # 20% margin above max data

        # Ensure y_max_plot is definitely greater than y_min_plot
        if y_max_plot <= y_min_plot: y_max_plot = y_min_plot * 10 # Create a reasonable span
    else: # No data to plot, use default reasonable limits for a log scale
        y_min_plot, y_max_plot = 1e-6, 1.5 # Default if no data

    for k_idx_plot, k_val_plot_iter in enumerate(k_values):
        ax = axes[k_idx_plot]
        k_ratio_plot_iter = k_val_plot_iter / n_dim
        k_ratio_percentage_val = k_ratio_plot_iter * 100
        title_string = f"$k = {k_val_plot_iter}$ (${k_ratio_percentage_val:.1f}\\%$ of n)"
        ax.set_title(title_string, fontsize=plt.rcParams['axes.titlesize'], fontweight='bold')

        if k_idx_plot == 0:
            ax.set_ylabel("Relative Frobenius Error", fontsize=plt.rcParams['axes.labelsize'], fontweight='bold')
        ax.set_xlabel("$\\rho_G$", fontsize=plt.rcParams['axes.labelsize'], fontweight='bold')
        ax.set_xscale('log')
        ax.set_yscale('log')

        for method_key_plot_iter, style_params_iter in styles.items():
            if method_key_plot_iter == 'Optimal Error v_k^*': continue # Not plotting v_k* here

            x_coords_plot, y_coords_plot = [], []
            for rho_val_plot_iter_inner in rho_values_plot: # Iterate through actual rho values obtained
                result_inner = results_data[rho_val_plot_iter_inner]
                k_idx_res_inner = np.where(result_inner['k'] == k_val_plot_iter)[0]
                if len(k_idx_res_inner) > 0 and method_key_plot_iter in result_inner and isinstance(result_inner[method_key_plot_iter], np.ndarray):
                    val_inner = result_inner[method_key_plot_iter][k_idx_res_inner[0]]
                    if not np.isnan(val_inner) and val_inner > 1e-12: # Ensure valid for sqrt and log scale
                        x_coords_plot.append(rho_val_plot_iter_inner)
                        y_coords_plot.append(np.sqrt(val_inner)) # Plotting sqrt of squared error
            if x_coords_plot: # Only plot if there's data
                ax.plot(x_coords_plot, y_coords_plot, **style_params_iter)

        ax.grid(True, which='both', linestyle=':', alpha=0.6)
        if rho_values_plot: # Set x-limits based on the range of rho values
            ax.set_xlim(min(rho_values_plot) * 0.8, max(rho_values_plot) * 1.2)
        ax.set_ylim(y_min_plot, y_max_plot) # Apply adaptive y-limits
        ax.tick_params(axis='both', which='major', labelsize=plt.rcParams['xtick.labelsize'])

    plt.tight_layout(rect=[0, 0, 1, 0.93]) # Adjust layout to make space for suptitle
    plot_filename_base = f"plots/experiment4_rho_vs_error_n{n_dim}"
    plt.savefig(f"{plot_filename_base}.png", dpi=plt.rcParams['savefig.dpi'], bbox_inches='tight')
    plt.savefig(f"{plot_filename_base}.pdf", bbox_inches='tight')
    plt.show()

    # Create a separate figure for the legend
    fig_legend = plt.figure(figsize=(16, 1.5)) # Adjust size as needed
    ax_legend = fig_legend.add_subplot(111)
    handles_legend, labels_legend = [], []

    for method_key_leg, style_params_leg in styles.items():
        if method_key_leg == 'Optimal Error v_k^*': continue # Don't include v_k* in this legend
        # Create dummy lines for legend handles
        line_leg, = ax_legend.plot([], [], **style_params_leg) # Use style_params directly
        handles_legend.append(line_leg)
        labels_legend.append(style_params_leg.get('label', method_key_leg)) # Get label from style_params

    # Add legend to the new figure
    ax_legend.legend(handles_legend, labels_legend, loc='center', ncol=min(5, len(handles_legend)), frameon=True, fontsize=plt.rcParams['legend.fontsize']-2)
    ax_legend.axis('off') # Turn off axis for legend-only figure
    plt.savefig(f"plots/experiment4_rho_vs_error_legend.png", dpi=plt.rcParams['savefig.dpi'], bbox_inches='tight')
    plt.savefig(f"plots/experiment4_rho_vs_error_legend.pdf", bbox_inches='tight')
    plt.close(fig_legend) # Close the legend figure
    print(f"Plots saved to {plot_filename_base}.png/pdf and plots/experiment4_rho_vs_error_legend.png/pdf")

def _to_json_native(value):
    """Recursively convert numpy arrays/scalars (incl. inside dicts, lists, tuples) to JSON-native types."""
    if isinstance(value, np.ndarray):
        # Object arrays may still contain numpy scalars after tolist(); recurse for those only.
        return _to_json_native(value.tolist()) if value.dtype == object else value.tolist()
    if isinstance(value, np.bool_):  # must precede the integer check; json cannot encode np.bool_
        return bool(value)
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, dict):
        return {_to_json_native(k): _to_json_native(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json_native(v) for v in value]
    return value


def save_results_to_json(results_data: Dict[float, Dict], n_dim: int, k_values_list: List[int], filename_prefix: str="experiment_rho_vs_error"):
    """Write per-rho experiment results to ``results/<prefix>_n<n>_k_values_<k1_k2...>.json``.

    Args:
        results_data: maps rho_G -> results mapping. Numpy arrays and scalars
            (including ones nested in dicts/lists) are converted to native
            Python types so ``json.dump`` can encode them.
        n_dim: problem dimension n, embedded in the filename.
        k_values_list: k values, joined with '_' into the filename.
        filename_prefix: leading part of the output filename.

    The rho keys are stored as ``str(rho)`` because JSON object keys must be strings.
    The ``results`` directory is created if it does not exist.
    NaN values are written as the non-standard ``NaN`` token (``json.dump`` default).
    """
    serializable_results_data = {
        str(rho_val): {key: _to_json_native(val) for key, val in result.items()}
        for rho_val, result in results_data.items()
    }
    k_str_json = "_".join(map(str, k_values_list))
    output_filename = f"results/{filename_prefix}_n{n_dim}_k_values_{k_str_json}.json"
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    with open(output_filename, 'w') as f_json:
        json.dump(serializable_results_data, f_json, indent=4)
    print(f"Results saved to {output_filename}")